{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text",
    "id": "ebUMqK9mGIDm"
   },
   "source": [
    "## The basics: interactive NumPy on GPU and TPU\n",
    "\n",
    "---\n",
    "\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "colab": {},
    "colab_type": "code",
    "id": "27TqNtiQF97X"
   },
   "outputs": [],
   "source": [
    "import jax\n",
    "import jax.numpy as jnp\n",
    "from jax import random"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "colab": {},
    "colab_type": "code",
    "id": "cRWoxSCNGU4o"
   },
   "outputs": [],
   "source": [
    "# Split the root key: keep `key` for any later splits and consume only `subkey`,\n",
    "# so that no PRNG key is used for more than one purpose.\n",
    "key = random.PRNGKey(0)\n",
    "key, subkey = random.split(key)\n",
    "x = random.normal(subkey, (5000, 5000))\n",
    "\n",
    "print(x.shape)\n",
    "print(x.dtype)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "colab": {},
    "colab_type": "code",
    "id": "diPllsvgGfSA"
   },
   "outputs": [],
   "source": [
    "y = jnp.dot(x, x)\n",
    "print(y[0, 0])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "colab": {},
    "colab_type": "code",
    "id": "8-psauxnGiRk"
   },
   "outputs": [],
   "source": [
    "x"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "colab": {},
    "colab_type": "code",
    "id": "-2FMQ8UeoTJ8"
   },
   "outputs": [],
   "source": [
    "import matplotlib.pyplot as plt\n",
    "\n",
    "plt.plot(x[0])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "print(jnp.dot(x, x.T))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "colab": {},
    "colab_type": "code",
    "id": "z4VX5PkMHJIu"
   },
   "outputs": [],
   "source": [
    "print(jnp.dot(x, 2 * x)[[0, 2, 1, 0], ..., None, ::-1])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "colab": {},
    "colab_type": "code",
    "id": "ORZ9Odu85BCJ"
   },
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "\n",
    "x_cpu = np.array(x)\n",
    "%timeit -n 5 -r 2 np.dot(x_cpu, x_cpu)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "colab": {},
    "colab_type": "code",
    "id": "5BKh0eeAGvO5"
   },
   "outputs": [],
   "source": [
    "%timeit -n 5 -r 5 jnp.dot(x, x).block_until_ready()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text",
    "id": "fm4Q2zpFHUAu"
   },
   "source": [
    "## Automatic differentiation"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "colab": {},
    "colab_type": "code",
    "id": "MCIQbyUYHWn1"
   },
   "outputs": [],
   "source": [
    "from jax import grad"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "colab": {},
    "colab_type": "code",
    "id": "kfqZpKYsHo4j"
   },
   "outputs": [],
   "source": [
    "def f(x):\n",
    "  \"\"\"Return 2 * x ** 3 when x > 0, otherwise 3 * x.\"\"\"\n",
    "  return 2 * x ** 3 if x > 0 else 3 * x"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "colab": {},
    "colab_type": "code",
    "id": "K_26_odPHqLJ"
   },
   "outputs": [],
   "source": [
    "key = random.PRNGKey(0)\n",
    "x = random.normal(key, ())\n",
    "\n",
    "print(grad(f)(x))\n",
    "print(grad(f)(-x))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "colab": {},
    "colab_type": "code",
    "id": "q5V3A6loHrhS"
   },
   "outputs": [],
   "source": [
    "print(grad(grad(f))(-x))\n",
    "print(grad(grad(grad(f)))(-x))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text",
    "id": "bmxAPFC0I8b0"
   },
   "source": [
    "Other JAX autodiff highlights:\n",
    "\n",
    "*   Forward- and reverse-mode, totally composable\n",
    "*   Fast Jacobians and Hessians\n",
    "*   Complex number support (holomorphic and non-holomorphic)\n",
    "*   Jacobian pre-accumulation for elementwise operations (like `gelu`)\n",
    "\n",
    "\n",
    "For much more, see the [JAX Autodiff Cookbook (Part 1)](https://jax.readthedocs.io/en/latest/notebooks/autodiff_cookbook.html)."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text",
    "id": "TRkxaVLJKNre"
   },
   "source": [
    "## End-to-end compilation with XLA using `jit`"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "colab": {},
    "colab_type": "code",
    "id": "bKo4rX9-KSW7"
   },
   "outputs": [],
   "source": [
    "from jax import jit"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "colab": {},
    "colab_type": "code",
    "id": "94iIgZSfKWh8"
   },
   "outputs": [],
   "source": [
    "key = random.PRNGKey(0)\n",
    "x = random.normal(key, (5000, 5000))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "colab": {},
    "colab_type": "code",
    "id": "Ybuz8Ag9KXMd"
   },
   "outputs": [],
   "source": [
    "def f(x):\n",
    "  \"\"\"Apply y <- y - 0.1 * y + 3. ten times, then return the first 100 rows and columns.\"\"\"\n",
    "  result = x\n",
    "  for _step in range(10):\n",
    "    result = result - 0.1 * result + 3.\n",
    "  return result[:100, :100]\n",
    "\n",
    "f(x)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "colab": {},
    "colab_type": "code",
    "id": "Y9dx5ifSKaGJ"
   },
   "outputs": [],
   "source": [
    "g = jit(f)\n",
    "g(x)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "colab": {},
    "colab_type": "code",
    "id": "UtsS67BvKYkC"
   },
   "outputs": [],
   "source": [
    "%timeit f(x).block_until_ready()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "colab": {},
    "colab_type": "code",
    "id": "-vfcaSo9KbvR"
   },
   "outputs": [],
   "source": [
    "%timeit -n 100 g(x).block_until_ready()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "colab": {},
    "colab_type": "code",
    "id": "E3BQF1_AKeLn"
   },
   "outputs": [],
   "source": [
    "grad(jit(grad(jit(grad(jnp.tanh)))))(1.0)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text",
    "id": "Tmf1NT2Wqv5p"
   },
   "source": [
    "## Parallelization over multiple accelerators with `pmap`"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "colab": {},
    "colab_type": "code",
    "id": "t6RRAFn1CEln"
   },
   "outputs": [],
   "source": [
    "jax.device_count()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "colab": {},
    "colab_type": "code",
    "id": "tEK1I6Duqunw"
   },
   "outputs": [],
   "source": [
    "from jax import pmap"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "colab": {},
    "colab_type": "code",
    "id": "S-iCNfeGqzkY"
   },
   "outputs": [],
   "source": [
    "# pmap maps the leading axis across devices, so size the input from the local\n",
    "# device count instead of hard-coding 8 accelerators.\n",
    "y = pmap(lambda x: x ** 2)(jnp.arange(jax.local_device_count()))\n",
    "print(y)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "colab": {},
    "colab_type": "code",
    "id": "xgutf5JPP3wi"
   },
   "outputs": [],
   "source": [
    "y"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "colab": {},
    "colab_type": "code",
    "id": "uvDL2_bCq7kq"
   },
   "outputs": [],
   "source": [
    "import matplotlib.pyplot as plt\n",
    "plt.plot(y)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text",
    "id": "xf5N9ZRirJhL"
   },
   "source": [
    "### Collective communication operations"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from functools import partial\n",
    "from jax.lax import psum\n",
    "\n",
    "@partial(pmap, axis_name='i')\n",
    "def f(x):\n",
    "  \"\"\"Return x divided by its sum over the 'i' axis, together with that sum.\"\"\"\n",
    "  total = psum(x, 'i')  # sum of x over the mapped axis named 'i'\n",
    "  return x / total, total\n",
    "\n",
    "# One value per local device, so the cell does not assume exactly 8 accelerators.\n",
    "# Note: with a single device the input is [0.], the total is 0 and the result is nan.\n",
    "normalized, total = f(jnp.arange(float(jax.local_device_count())))\n",
    "\n",
    "print(f\"normalized:\\n{normalized}\\n\")\n",
    "print(\"total:\", total)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text",
    "id": "jC-KIMQ1q-lK"
   },
   "source": [
    "For more, see the [`pmap` cookbook](https://colab.research.google.com/github/google/jax/blob/main/cloud_tpu_colabs/Pmap_Cookbook.ipynb)."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Automatic parallelization with `sharded_jit` (new!)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from jax.experimental import sharded_jit, PartitionSpec as P"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from jax import lax\n",
    "\n",
    "conv = lambda image, kernel: lax.conv(image, kernel, (1, 1), 'SAME')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "image = jnp.ones((1, 8, 2000, 1000), dtype=jnp.float32)\n",
    "# Seeded generator so the kernel, and therefore the displayed result, is the same on every run.\n",
    "kernel = jnp.array(np.random.RandomState(0).random_sample((8, 8, 5, 5)).astype(np.float32))\n",
    "\n",
    "np.set_printoptions(edgeitems=1)\n",
    "conv(image, kernel)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "%timeit conv(image, kernel).block_until_ready()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "image_partitions = P(1, 1, 4, 2)\n",
    "sharded_conv = sharded_jit(conv, \n",
    "                           in_parts=(image_partitions, None), \n",
    "                           out_parts=image_partitions)\n",
    "\n",
    "sharded_conv(image, kernel)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "%timeit -n 10 sharded_conv(image, kernel).block_until_ready()"
   ]
  }
 ],
 "metadata": {
  "accelerator": "TPU",
  "colab": {
   "collapsed_sections": [
    "AvXl1WDPKjmV"
   ],
   "name": "JAX demo.ipynb",
   "provenance": []
  },
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.9"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
